from torch.utils.data import TensorDataset, Dataset

class MyDataSet(Dataset):
    """Map-style dataset pairing inputs ``m`` with targets ``n`` by index.

    ``m`` and ``n`` may be any objects supporting ``len()`` and indexing
    (e.g. lists, arrays, tensors). Item ``i`` is the tuple ``(m[i], n[i])``.
    """

    def __init__(self, m, n):
        """Store the inputs and their aligned targets.

        Args:
            m: Input samples; stored as ``self.X``.
            n: Targets aligned element-wise with ``m``; stored as ``self.y``.

        Raises:
            ValueError: If ``m`` and ``n`` have different lengths.
        """
        # Fail fast: a length mismatch would otherwise either raise IndexError
        # mid-epoch (too few targets) or silently drop targets (too many).
        if len(m) != len(n):
            raise ValueError(
                f"Size mismatch between inputs ({len(m)}) and targets ({len(n)})"
            )
        self.X = m
        self.y = n

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.X)``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]
